import os
from marker.scripts.common import (
    load_models,
    parse_args,
    img_to_html,
    get_page_image,
    page_count,
)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["IN_STREAMLIT"] = "true"
from marker.settings import settings
from streamlit.runtime.uploaded_file_manager import UploadedFile
import re
import tempfile
from typing import Any, Dict
import streamlit as st
from PIL import Image
from marker.config.parser import ConfigParser
from loguru import logger

from converters.pdf import PdfConverter
from renderers.output import text_from_rendered

def convert_pdf(fname: str, config_parser: ConfigParser) -> Any:
    """Convert the file at ``fname`` with a freshly built ``PdfConverter``.

    The converter is assembled from the settings held by ``config_parser``
    and from the module-level ``model_dict`` (the result of ``load_models()``).

    Args:
        fname: Path of the file to convert.
        config_parser: Parsed options; supplies the config dict, the
            processor list, the renderer and the LLM service.

    Returns:
        Whatever calling the converter returns, i.e. the rendered document
        object. This is *not* a ``(text, ext, images)`` triple: callers get
        that triple by passing the result to ``text_from_rendered``.
    """
    config_dict = config_parser.generate_config_dict()
    # This app always pins pdftext to a single worker, overriding any value
    # that came from the parsed options.
    config_dict["pdftext_workers"] = 1
    converter = PdfConverter(
        config=config_dict,
        artifact_dict=model_dict,
        processor_list=config_parser.get_processors(),
        renderer=config_parser.get_renderer(),
        llm_service=config_parser.get_llm_service(),
    )
    return converter(fname)


def markdown_insert_images(markdown, images):
    """Swap markdown image tags for inline HTML when the image is available.

    Args:
        markdown: Markdown text to scan for ``![alt](path ...)`` image tags.
        images: Container keyed by image path; it is queried with ``in`` and
            indexed with the path found in each tag.

    Returns:
        The markdown with every tag whose path is present in ``images``
        replaced by ``img_to_html(images[path], alt)``. Tags whose path is
        not in ``images`` are left untouched.
    """
    tag_pattern = re.compile(
        r'(!\[(?P<image_title>[^\]]*)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))'
    )
    # All tags are collected from the original text before any replacement
    # happens, so inserted HTML is never re-scanned.
    for full_tag, alt_text, path, _trailing in tag_pattern.findall(markdown):
        if path not in images:
            continue
        replacement = img_to_html(images[path], alt_text)
        markdown = markdown.replace(full_tag, replacement)
    return markdown

def markdown_remove_images(markdown):
    """Return ``markdown`` with every ``![alt](path ...)`` image tag deleted.

    A tag must have a non-empty path to match; an optional trailing part
    after the path (such as a quoted title) is removed along with the tag.
    Text surrounding a tag, including whitespace, is kept as is.
    """
    image_tag = re.compile(
        r'(!\[(?P<image_title>[^\]]*)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))'
    )
    return image_tag.sub("", markdown)


# Page setup: use the wide layout, load the models and parse the CLI options.
# ``model_dict`` is read as a module-level global by ``convert_pdf``.
st.set_page_config(layout="wide")
model_dict = load_models()
cli_options = parse_args()

# Sidebar upload widget, restricted to the listed file extensions.
in_file: UploadedFile = st.sidebar.file_uploader(
    "PDF, document, or image file:",
    type=["pdf", "png", "jpg", "jpeg", "gif", "pptx", "docx", "xlsx", "html", "epub"],
)
# Nothing below runs until a file has been uploaded.
if in_file is None:
    st.stop()

# Log the name of the uploaded document.
logger.info(f'解析文档：{in_file.name}')
# NOTE(review): ``filetype`` is not read again in the visible code.
filetype = in_file.type

# NOTE: this rebinds the imported ``page_count`` helper to its own result, so
# the helper is no longer reachable under that name for the rest of this run.
page_count = page_count(in_file)
# Page picker for the preview; it starts at 0 and allows values up to
# ``page_count`` inclusive.
page_number = st.number_input(
    f"Page number out of {page_count}:", min_value=0, value=0, max_value=page_count
)
# Two columns with relative widths 0.3 / 0.7: the page preview goes in the
# first, the conversion output is later written into the second.
col1, col2 = st.columns([0.3, 0.7])
with col1:
    pil_image = get_page_image(in_file, page_number)
    st.image(pil_image, use_container_width=True)

# Sidebar: page range to convert, pre-filled with the page being previewed.
page_range = st.sidebar.text_input(
    "Page range to parse, comma separated like 0,5-10,20",
    value=f"{page_number}-{page_number}",
)
# Log the selected page range (runs on every script execution, not only when
# a conversion is started).
logger.info(f'页面选择：{page_range}')
output_format = st.sidebar.selectbox(
    "Output format", ["markdown", "json", "html", "chunks"], index=0
)
# The run button is created before the option checkboxes, so it appears above
# them in the sidebar.
run_marker = st.sidebar.button("Run Marker")

# Conversion options; every checkbox defaults to off. The values are copied
# into ``cli_options`` when a conversion is started.
use_llm = st.sidebar.checkbox(
    "Use LLM", help="Use LLM for higher quality processing", value=False
)
force_ocr = st.sidebar.checkbox("Force OCR", help="Force OCR on all pages", value=False)
strip_existing_ocr = st.sidebar.checkbox(
    "Strip existing OCR",
    help="Strip existing OCR text from the PDF and re-OCR.",
    value=False,
)
debug = st.sidebar.checkbox("Debug", help="Show debug information", value=False)
format_lines = st.sidebar.checkbox(
    "Format lines",
    help="Format lines in the document with OCR model",
    value=False,
)
disable_ocr_math = st.sidebar.checkbox(
    "Disable math",
    help="Disable math in OCR output - no inline math",
    value=False,
)

# Stop here unless the "Run Marker" button was pressed on this run.
if not run_marker:
    st.stop()

# Run the conversion. The upload is first written to a temporary file so the
# converter can be given a filesystem path; the directory is removed when the
# ``with`` block exits.
with tempfile.TemporaryDirectory() as tmp_dir:
    # The temporary file is always named "temp.pdf", whatever type of file
    # was uploaded.
    temp_pdf = os.path.join(tmp_dir, "temp.pdf")
    with open(temp_pdf, "wb") as f:
        f.write(in_file.getvalue())

    # Overlay the sidebar choices onto the options parsed from the command
    # line. ``output_dir`` is only set when debug mode is on.
    cli_options.update(
        {
            "output_format": output_format,
            "page_range": page_range,
            "force_ocr": force_ocr,
            "debug": debug,
            "output_dir": settings.DEBUG_DATA_FOLDER if debug else None,
            "use_llm": use_llm,
            "strip_existing_ocr": strip_existing_ocr,
            "format_lines": format_lines,
            "disable_ocr_math": disable_ocr_math,
        }
    )
    config_parser = ConfigParser(cli_options)
    rendered = convert_pdf(temp_pdf, config_parser)
    # ``page_range`` is rebound here from the sidebar string to the value in
    # the generated config dict (presumably a parsed sequence of page
    # indices - verify against ConfigParser).
    page_range = config_parser.generate_config_dict()["page_range"]
    # First entry of the range, or 0 when the range is empty/None; used to
    # pick which debug images to show.
    first_page = page_range[0] if page_range else 0

# Split the rendered result into its text, extension and images.
# NOTE(review): ``ext`` and ``images`` are not used by the active code below.
text, ext, images = text_from_rendered(rendered)
# Show the result in the second (wider) column, using the widget that matches
# the selected output format.
with col2:
    if output_format == "markdown":
        # Image tags are stripped from the markdown instead of being inlined
        # as HTML; the inlining call is kept here, disabled:
        # text = markdown_insert_images(text, images)
        text = markdown_remove_images(text)
        # Log the full converted text.
        logger.info(f'解析结果：\n{text}')
        # Raw HTML contained in the converted text is rendered as HTML.
        st.markdown(text, unsafe_allow_html=True)
    elif output_format == "json":
        st.json(text)
    elif output_format == "html":
        st.html(text)
    elif output_format == "chunks":
        st.json(text)

# Plain-text copy of the output, placed outside the columns. It is shown for
# every output format even though its label says "markdown".
st.text_area(label='markdown文本', value=text, key='text', height=500)

# Debug mode: add the debug images (when available) and the raw output to the
# first column, below the page preview.
if debug:
    with col1:
        debug_data_path = rendered.metadata.get("debug_data_path")
        # Images are only shown when the metadata reports a debug data path.
        if debug_data_path:
            # Both images are looked up by the first page of the parsed range.
            pdf_image_path = os.path.join(debug_data_path, f"pdf_page_{first_page}.png")
            img = Image.open(pdf_image_path)
            st.image(img, caption="PDF debug image", use_container_width=True)
            layout_image_path = os.path.join(
                debug_data_path, f"layout_page_{first_page}.png"
            )
            img = Image.open(layout_image_path)
            st.image(img, caption="Layout debug image", use_container_width=True)
        # The raw output is shown whether or not debug images were found.
        st.write("Raw output:")
        st.code(text, language=output_format)
